
# GAMMs and prediction plots ---- 

# In this script we are going to fit our GAMMs, and make plots of their predictions

#We use the random-effects structure suggested by Marianne (subject nested within family)
#and add age and gender as covariates, entered as parametric main effects

#For consistency between measures we use summary data (one average value per electrode per subject). I considered
#using all of the data (e.g. spectral data per epoch, or every individual spindle) but decided against it, because some
#measures (e.g. density and coupling) cannot be measured at that scale, which would give inconsistently sized datasets.
#Summary data are therefore used for every measure.


# Set up =====

#Load packages
pacman::p_load(tidyverse, brms, tidybayes, bayestestR, 
               modelr, eegUtils, patchwork)

#Get our default settings
source("./eLife Submission Scripts/Analysis-Common-Utilities.R")


# Load Datasets =====

#Read the per-electrode summary data (the `_z` file - presumably z-scored) and drop the
#phase-angle measure straight away: angle models are fit differently (von Mises GAMM) - see below
d_gamm = 
  readr::read_rds("./eLife Submission Data/sleep_study_eeg_summary_data_z.rds") |>
  dplyr::filter(measure != "itpc_overlap_angle")

# Fit GAMMs =====

## Fit Models =====

# Set up our model

#This is the model formula. It includes a 2D smooth of electrode position (x, y) for each group to account for the
#60 electrodes; demographic factors - age at the time of EEG, gender and, of course, genotype (group); and
#varying intercepts for family and for subject nested within family

bf_gamm = bf(
  value ~ s(x, y, by = group, bs = "tp", k = 20, m = 2) +
    age_eeg + gender + group +
    (1 | family) + (1 | family:subject)
)

#List the parameter classes that need priors, using the first nested dataset as a template
get_prior(formula = bf_gamm,
          data    = d_gamm$data[[1]],
          family  = gaussian(link = "identity"))

#We use regularising priors as described by McElreath 

#Fit a GAMM to all datasets: TAKES A VERY LONG TIME
#Fit one Gaussian GAMM (bf_gamm) per row of d_gamm, i.e. one per nested dataset in `data`.
#Priors: normal(0, 1) on population-level effects and the intercept, exponential(1) on the
#residual sd, the varying-intercept sds and the smooth sds (data are presumably z-scored - `_z` file).
d_gamm = 
  d_gamm %>%
  mutate(model = map(data, ~  brm(data    = .x, 
                                  family  = gaussian(link = "identity"),
                                  formula = bf_gamm,
                                  prior = c(
                                    prior(normal(0, 1)  , class = b),
                                    prior(normal(0, 1)  , class = Intercept),
                                    prior(exponential(1), class = sigma),
                                    prior(exponential(1), class = sd),
                                    prior(exponential(1), class = sds)),
                                  iter = 4000, warmup = 1000, chains = 4, cores = 4,
                                  seed = 14,
                                  # FALSE rather than F: F is an ordinary binding that can be reassigned
                                  sample_prior = FALSE,
                                  # high adapt_delta to reduce divergent transitions in the smooth terms
                                  control = list(adapt_delta = 0.99))))

#Save our GAMM data immediately as it takes a long time to compute
write_rds(d_gamm,"./Data/bayesian_gamms.rds")


## Extract posterior samples ====

#Read the dataset if we aren't doing this all in one go
# d_gamm = read_rds("./Data/bayesian_gamms.rds")

#Make a sampling grid for posterior predictions

#Set number of points in grid
grid_size = 51

#Prediction grid: grid_size x grid_size points spanning [-0.5, 0.5] in x and y, crossed with
#every gender and group level, with age_eeg held at 0
d_gamm_grid = 
  modelr::data_grid(
    d_gamm$data[[1]],
    x       = seq(-0.5, 0.5, length.out = grid_size),
    y       = seq(-0.5, 0.5, length.out = grid_size),
    gender  = levels(gender),
    age_eeg = 0,
    group   = levels(group)
  )


#Make a grid of predicted differences using the add_epred_draws function
d_gamm = 
  d_gamm %>%
  mutate(post_draws = map(model, function(fit) {
    # 2000 draws of the expected value at every grid row; re_formula = NA drops the
    # family / subject varying intercepts from the predictions
    epred = add_epred_draws(newdata    = d_gamm_grid,
                            object     = fit,
                            re_formula = NA,
                            ndraws     = 2000)

    # Pair the 22q and Sib draws. gender is not selected, so each (x, y, .draw) holds one
    # value per gender level; values_fn = list keeps them all and unnest() lines the two
    # group columns up element by element before differencing
    paired = epred %>%
      ungroup() %>%
      dplyr::select(x, y, group, .draw, .epred) %>%
      pivot_wider(names_from = group, values_from = .epred, values_fn = list) %>%
      unnest(c(`22q`, Sib)) %>%
      mutate(diff = `22q` - Sib)

    # Per grid point: probability of direction of the difference, its median and 95% HDI,
    # an alpha value (1 if two-sided p < 0.05, else 0.1) and whether the HDI excludes 0
    paired %>%
      group_by(x, y) %>%
      nest() %>%
      mutate(pd    = map(data, \(d) p_direction(d) %>% filter(Parameter == "diff")),
             m_dif = map(data, \(d) median_hdi(d, diff, .width = 0.95))) %>%
      unnest(c(pd, m_dif)) %>%
      mutate(pv    = pd_to_p(pd, direction = "two-sided"),
             pv    = ifelse(pv < 0.05, 1, 0.1),
             exc_1 = ifelse((.lower > 0 & .upper > 0) | (.lower < 0 & .upper < 0), 1, 0))
  }))


## Basic Topoplots =====

#Plot our predictions onto a topoplot of the head for quality control
d_gamm = 
  d_gamm %>%
  mutate(
    ptitle = paste(measure, stage, sep = " "),
    # One head-shaped raster per model: fill = median 22q - Sib difference (squished to
    # [-1, 1]), alpha = 1 where two-sided p < 0.05 and 0.1 elsewhere
    plots  = map2(post_draws, ptitle, \(post, title) {
      post %>%
        dplyr::select(x, y, diff, pv) %>%
        mutate(incircle = sqrt(x^2 + y^2) < circ_scale) %>%
        filter(incircle) %>%
        ggplot(aes(x = x, y = y, fill = diff)) +
        geom_raster(aes(alpha = pv)) +
        geom_mask(r = circ_scale, size = 1.5) +
        geom_head(r = circ_scale, size = 1) +
        scale_fill_distiller(palette = "RdBu",
                             limits  = c(-1, 1),
                             oob     = scales::squish) +
        coord_equal() +
        theme_void() +
        theme(legend.position = "bottom",
              strip.text = element_text(colour = "grey20", size = 8, angle = -90)) +
        guides(alpha = "none") +
        scale_alpha(range = c(0.1, 1)) +
        labs(title = title)
    })
  )

#Quick plot of our GAMM results
cowplot::plot_grid(plotlist = d_gamm$plots, ncol = 5)

## Save GAMMs =====

# Overwrite the earlier checkpoint so the saved object also holds post_draws, ptitle and plots
readr::write_rds(d_gamm, "./Data/bayesian_gamms.rds")

## Save posterior data only =====

# We do this to generate data that can be loaded quickly for plotting, and used in the
#eLife submission
#Keep only the per-grid-point posterior summaries (no fitted models) for fast plotting
d_post = 
  d_gamm |>
  dplyr::select(stage,measure,post_draws) |>
  unnest(post_draws) |>
  # drop the nested raw draws and the interval metadata columns
  dplyr::select(-c(data,Parameter,.width,.point,.interval)) |>
  # NOTE(review): nest() with no columns nests per group if d_gamm is still grouped
  # (e.g. by stage/measure) and into a single row otherwise - confirm which is intended
  nest() |>
  rename(post_draws = data)

#"eLife" capitalised to match the input directory used when loading the data;
#"elife" is a different directory on case-sensitive file systems
write_rds(d_post,"./eLife Submission Data/sleep_study_topoplot_posterior_data.rds")


# Angle data ------


#In this section, we will fit a model to the phase angle of spindle-SO coupling, which is a little more complicated
#than the gaussian models above

#Load some additional libraries
pacman::p_load(circular,CircStats,colorspace)


#Load data

# Keep only the phase-angle measure, flatten its nested data column and drop any
# grouping carried over from the saved file
d_angle = 
  readr::read_rds("./eLife Submission Data/sleep_study_eeg_summary_data_z.rds") |>
  dplyr::filter(measure == "itpc_overlap_angle") |>
  tidyr::unnest(data) |>
  dplyr::ungroup()

#Convert age to a z score, convert angles to radians (from degrees)
d_angle2 = 
  d_angle |> 
  # give the generic `value` column a descriptive name
  rename(itpc_overlap_angle = value) |>
  # Sib first, so it is the first group level
  mutate(group = factor(group, levels = c("Sib", "22q"))) |>
  drop_na(itpc_overlap_angle) |>
  # age is standardised after dropping missing angles; degrees -> radians
  mutate(age_eeg            = zscore(age_eeg),
         itpc_overlap_angle = itpc_overlap_angle * pi / 180)


## Plot angular data ---------

# Make a plot from the mean angle data 
# Observed data: per group and electrode, summarise the angles (kappa estimate, mean
# resultant length, circular mean, circ.disp() statistics), then interpolate the circular
# mean (in degrees) across the scalp with a circular colour scale spanning -180..180
p_angle = 
  d_angle2 |>
  group_by(group, electrode) |>
  summarise(e_k = est.kappa(itpc_overlap_angle),
            mrl = est.rho(itpc_overlap_angle),
            m_a = circ.mean(itpc_overlap_angle),
            c_d = as_tibble(circ.disp(itpc_overlap_angle))) |>
  left_join(topo, by = "electrode") |>
  ungroup() |>
  mutate(m_a = pracma::rad2deg(m_a)) |>
  ggplot(aes(x = x, y = y, z = m_a, fill = m_a, label = electrode)) +
  geom_topo(grid_res     = 200,
            interp_limit = "head",
            chan_markers = "point",
            chan_size    = 1,
            method       = "gam",
            color        = "black") +
  facet_wrap(~group, ncol = 2) +
  theme_void() +
  coord_equal() +
  labs(fill = "Angle", subtitle = "ITPC Overlap - Angle") +
  scale_fill_gradientn(colours = rainbow_hcl(100, l = 70),
                       limits  = c(-180, 180),
                       breaks  = c(-180, -90, 0, 90, 180))



## Circular GAMM  =====

#Set the model formula
#Set the model formula: both the mean angle (mu) and the concentration (kappa) get per-group
#intercepts (0 + group), a per-group 2D scalp smooth, age and gender covariates and
#family / subject-within-family varying intercepts
bf_vm_gamm2 = bf(itpc_overlap_angle ~ 0 + group + s(x,y,by = group, bs = "tp", k = 20, m = 2) + age_eeg + gender + (1|family) + (1|family:subject),
                              kappa ~ 0 + group + s(x,y,by = group, bs = "tp", k = 20, m = 2) + age_eeg + gender + (1|family) + (1|family:subject))

#Check required priors. The model must use d_angle2: d_angle still calls the response `value`
#and holds degrees, whereas d_angle2 has the itpc_overlap_angle column in radians (as the von
#Mises family requires), a z-scored age and Sib as the first group level
get_prior(data    = d_angle2, 
          family  = "von_mises",
          formula = bf_vm_gamm2)


#Fit the model. Note we are running this model for many more samples than the other models to improve fit
vm_gamm  =
  brm(data    = d_angle2, 
      family  = "von_mises",
      formula = bf_vm_gamm2,
      # NOTE(review): priors given without `dpar` apply to mu only; the kappa terms keep brms'
      # default priors. Add e.g. prior(normal(0, 1), class = b, dpar = kappa) if that was intended.
      prior = c(
        prior(normal(0, 1)  , class = b),
        prior(exponential(1), class = sd),
        prior(exponential(1), class = sds)),
      iter = 6000, warmup = 2000, chains = 4, cores = 4,
      seed = 1234,
      sample_prior = FALSE,
      control = list(adapt_delta = 0.95))

#Print model summary 
summary(vm_gamm)

#Inspect model output
plot(vm_gamm)

#Look at our mean angles in degrees. fixef() reports mu on the link scale, and brms' default
#von Mises link is "tan_half", so the link is inverted (mu = 2 * atan(eta)) before converting
#radians to degrees. The result is on the same -180..180 scale as the topoplots. The intervals map
#correctly because 2 * atan() is monotonic. Assumes rows 1:2 are the mu group terms - confirm.
fixef(vm_gamm)[1:2,] |> 
  as_tibble(rownames = "Term") |> 
  dplyr::select(-Est.Error) |> 
  mutate(across(where(is.double), ~ (2 * atan(.x)) * 180 / pi))

#Save our vm model together with the data it was actually fit to
d_vm_gamm = tibble(
  measure = "MRL_Angle",
  data = list(d_angle2),
  model = list(vm_gamm)
)

#Save
write_rds(d_vm_gamm,"./Data/bayesian_circular_gamm.rds")

## Plot posterior =====

#Load
#Reload the saved circular model
d_vm_gamm = readr::read_rds("./Data/bayesian_circular_gamm.rds")

#Prediction grid (grid_size x grid_size over [-0.5, 0.5], every gender and group level,
#age_eeg = 0), followed by 1000 population-level (re_formula = NA) expected-value draws per row
grid_size = 51

vm_gamm_grid = 
  modelr::data_grid(
    d_vm_gamm$data[[1]],
    x       = seq(-0.5, 0.5, length.out = grid_size),
    y       = seq(-0.5, 0.5, length.out = grid_size),
    gender  = levels(gender),
    age_eeg = 0,
    group   = levels(group)
  ) %>%
  add_epred_draws(object     = d_vm_gamm$model[[1]],
                  re_formula = NA,
                  ndraws     = 1000)



#Plot our estimated phases with a circular colour scale
# Circular mean of the expected-angle draws at every (x, y, group), stored as a list column
d_vm_gamm = 
  d_vm_gamm |>
  mutate(d_angle_model = list(
    vm_gamm_grid |>
      ungroup() |>
      dplyr::select(x, y, group, .epred) |>
      group_by(x, y, group) |>
      summarise(mu = circ.mean(.epred)) |>
      ungroup()
  ))


# Modelled mean angle (degrees) per group, masked to the head circle, on the same
# -180..180 circular colour scale as the observed-data plot
p_angle_model = 
  d_vm_gamm$d_angle_model[[1]] |>
  mutate(group    = factor(group, levels = c("Sib", "22q")),
         mu       = pracma::rad2deg(mu),
         incircle = sqrt(x^2 + y^2) < circ_scale) |>
  filter(incircle) |>
  ggplot(aes(x = x, y = y, fill = mu)) +
  geom_raster() +
  geom_mask(r = circ_scale, size = 3) +
  geom_head(r = circ_scale, size = 2) +
  scale_fill_gradientn(colours = rainbow_hcl(100, l = 70),
                       limits  = c(-180, 180),
                       breaks  = c(-180, -90, 0, 90, 180),
                       oob     = scales::squish) +
  coord_equal() +
  theme_void() +
  facet_wrap(~group, ncol = 2)


#Plot the observed and modelled angle data side by side to check that the model is
#making reasonable predictions (it should!)
p_angle | p_angle_model



#Now let's work out the modelled (22q - Sib) angle difference
d_vm_gamm = 
  d_vm_gamm |>
  mutate(d_angle_diff = list(
    vm_gamm_grid %>%
      ungroup() %>%
      dplyr::select(x, y, group, .draw, .epred) %>%
      # one column per group; values_fn = list keeps the per-gender values so that
      # unnest() pairs the 22q and Sib draws element by element
      pivot_wider(names_from = group, values_from = .epred, values_fn = list) %>%
      unnest(c(`22q`, Sib)) %>%
      # unsigned circular distance between the two angles, in [0, pi]
      mutate(diff_c = pi - abs(pi - abs(`22q` - Sib) %% (2*pi))) %>%
      # Signed circular difference wrapped into [-pi, pi]. A raw `22q - Sib` can reach +/- 2*pi
      # when the angles sit either side of the +/- pi seam (3.1 and -3.1 rad differ by about
      # -0.08 rad, not 6.2), which would distort the pd and HDI below. Assumes the posterior
      # difference is not itself concentrated near +/- pi.
      mutate(diff = atan2(sin(`22q` - Sib), cos(`22q` - Sib))) %>%
      group_by(x, y) %>%
      nest() %>%
      mutate(pd    = map(data, ~ p_direction(.x) %>% filter(Parameter == "diff")),
             m_dif = map(data, median_hdi, diff, .width = 0.95)) %>%
      unnest(c(pd, m_dif)) %>%
      # pd_to_p is vectorised, as in the Gaussian models; pv becomes an alpha value
      # (1 if two-sided p < 0.05, else 0.1)
      mutate(pv = pd_to_p(pd, direction = "two-sided"),
             pv = ifelse(pv < 0.05, 1, 0.1))
  ))

#Make a plot of predicted group differences
# Median modelled 22q - Sib angle difference in degrees (squished to -180..180), with
# alpha = 1 where two-sided p < 0.05 and 0.1 elsewhere, masked to the head circle
p_angle_diff_vm = 
  d_vm_gamm$d_angle_diff[[1]] |>
  mutate(diff_deg = pracma::rad2deg(diff)) |>
  dplyr::select(x, y, diff_deg, pv) |>
  mutate(incircle = sqrt(x^2 + y^2) < circ_scale) |>
  filter(incircle) |>
  ggplot(aes(x = x, y = y, fill = diff_deg)) +
  geom_raster(aes(alpha = pv)) +
  geom_mask(r = circ_scale, size = 3) +
  geom_head(r = circ_scale, size = 2) +
  scale_fill_distiller(palette = "RdBu",
                       limits  = c(-180, 180),
                       oob     = scales::squish) +
  coord_equal() +
  theme_void() +
  scale_alpha(range = c(0.1, 1))

#Make an assembled plot

## Save Data =====

# Observed angles (two group facets) on the left, twice as wide as the modelled
# 22q - Sib difference on the right
(p_angle | p_angle_diff_vm) +
  plot_layout(widths = c(2, 1))

#Save our model data (now including d_angle_model and d_angle_diff)
readr::write_rds(d_vm_gamm, "./Data/bayesian_circular_gamm.rds")



##Save posterior samples ======

#Save data and prepare the topoplot-able data for the eLife submission
#Flatten the per-grid-point angle-difference summaries, drop the nested draws and interval
#metadata, and write them to the same "eLife" directory the input data are read from
#("elife" is a different directory on case-sensitive file systems)
write_rds(d_vm_gamm |> 
            dplyr::select(d_angle_diff) |>
            unnest(d_angle_diff) |>
            dplyr::select(-c(Parameter,data,.width,.point,.interval)),
          "./eLife Submission Data/sleep_study_topoplot_angle_data.rds")
